import yaml
import os
import spacy
import torch
from tqdm import tqdm
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DistributedSampler, DataLoader
from torch.cuda.amp import GradScaler
from transformer import make_model
import torch.multiprocessing as mp
import GPUtil
from torch.optim.lr_scheduler import LambdaLR
from training import LabelSmoothing, rate, TrainState
from training import run_epoch, Batch, SimpleLossCompute
from training import DummySchedular, DummyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.datasets import load_local_dataset
from torchtext.vocab import build_vocab_from_iterator
import torch.nn.functional as F
import utils.logger
from logging import Logger


def _load_spacy_model(name, language):
    """Load the spaCy pipeline *name*; print a download hint and return None if it is missing.

    Args:
        name: spaCy package name passed to ``spacy.load``.
        language: human-readable language name used in the printed hint.
    """
    try:
        return spacy.load(name)
    except IOError:  # raised when the model package has not been downloaded
        print(f"The {language} model is not found. " +
              "Please download from https://spacy.io/usage")
        return None


def load_tokenizers():
    """Load the German and English spaCy pipelines used for tokenization.

    Both models are attempted before failing, so each missing model gets its
    own download hint printed.

    Returns:
        (spacy_de, spacy_en): the loaded ``de_core_news_sm`` and
        ``en_core_web_sm`` pipelines.

    Raises:
        OSError: if either model is not installed. This is an explicit raise
            rather than ``assert``, which is stripped under ``python -O`` and
            would let ``None`` tokenizers escape.
    """
    spacy_de = _load_spacy_model('de_core_news_sm', "German")
    spacy_en = _load_spacy_model('en_core_web_sm', "English")
    if spacy_de is None or spacy_en is None:
        raise OSError("Required spaCy model(s) missing; see the messages above.")
    return spacy_de, spacy_en


def tokenize(text, tokenizer):
    """Split *text* into token strings using a spaCy pipeline's tokenizer.

    Only the pipeline's ``tokenizer`` attribute is invoked, not the full
    pipeline object.

    Args:
        text: sentence to split.
        tokenizer: a spaCy pipeline (anything exposing ``.tokenizer(text)``
            that yields objects with a ``.text`` attribute).

    Returns:
        List of token strings, in order.
    """
    doc = tokenizer.tokenizer(text)
    return [token.text for token in doc]


def yield_tokens(data_iter, tokenizer, index):
    """Lazily yield the tokenized form of one side of every sentence pair.

    Args:
        data_iter: iterable of ``(de, en)`` tuples, where ``de`` is a German
            sentence and ``en`` an English sentence. These are languages, NOT
            encoder/decoder roles.
        tokenizer: callable mapping a sentence string to a list of tokens.
        index: which element of each pair to tokenize (0 = German, 1 = English).
    """
    yield from (tokenizer(pair[index]) for pair in data_iter)


def build_vocabulary(spacy_de, spacy_en):
    """Build the German (source) and English (target) vocabularies.

    Tokens are counted over the concatenated train, validation and test splits
    returned by ``load_local_dataset``. Only tokens occurring at least twice are
    kept. The specials ``<s>``, ``</s>``, ``<blank>`` and ``<unk>`` are always
    registered, and out-of-vocabulary lookups map to ``<unk>``.

    Args:
        spacy_de: German spaCy pipeline used to tokenize source sentences.
        spacy_en: English spaCy pipeline used to tokenize target sentences.

    Returns:
        (vocab_src, vocab_tgt)
    """
    def splitter(nlp):
        # Tokenizer-only split: returns the token strings of one sentence.
        return lambda text: [tok.text for tok in nlp.tokenizer(text)]

    # German -> English. The splits are unpacked positionally from the dict values.
    train, val, test = load_local_dataset().values()

    def vocab_for(message, nlp, side):
        print(message)
        return build_vocab_from_iterator(
            iterator=yield_tokens(train + val + test, splitter(nlp), index=side),
            min_freq=2,  # minimum frequency for a token to be added to the vocabulary
            specials=["<s>", "</s>", "<blank>", "<unk>"]
        )

    vocab_src = vocab_for("Building German vocabulary...", spacy_de, 0)
    vocab_tgt = vocab_for("Building English vocabulary...", spacy_en, 1)

    vocab_src.set_default_index(vocab_src["<unk>"])
    vocab_tgt.set_default_index(vocab_tgt["<unk>"])
    return vocab_src, vocab_tgt


def load_vocab(spacy_de, spacy_en):
    """Return the ``(vocab_src, vocab_tgt)`` pair, building and caching it on first use.

    The pair is cached in ``vocab.pt`` in the current working directory. If the
    file is absent, the vocabularies are built with ``build_vocabulary`` and
    saved there. Vocabulary sizes are printed in either case.
    """
    cache_file = "vocab.pt"
    if os.path.exists(cache_file):
        # NOTE(review): the cache holds pickled torchtext Vocab objects; torch
        # releases whose torch.load defaults to weights_only=True may refuse to
        # load them — confirm against the installed torch version.
        vocab_src, vocab_tgt = torch.load(cache_file)
    else:
        vocab_src, vocab_tgt = build_vocabulary(spacy_de, spacy_en)
        torch.save((vocab_src, vocab_tgt), cache_file)

    print("Vocabulary loaded.")
    for label, vocab in (("Source language vocabulary size:", vocab_src),
                         ("Target language vocabulary size:", vocab_tgt)):
        print(label, len(vocab))
    return vocab_src, vocab_tgt


def collate_batch(batch,
                  src_pipeline, tgt_pipeline,
                  src_vocab, tgt_vocab,
                  device, max_padding=128, pad_id=2):
    """Turn a batch of (source, target) sentence pairs into two padded id tensors.

    Each sentence is processed in four steps:
      1. tokenized by its pipeline;
      2. converted to ids by its vocab;
      3. framed by start id 0 and end id 1;
      4. right-padded with ``pad_id`` to exactly ``max_padding`` positions.

    A framed sentence longer than ``max_padding`` is cropped to
    ``max_padding`` (``F.pad`` with a negative amount), which drops its end id.

    Args:
        batch: iterable of ``(src_text, tgt_text)`` pairs.
        src_pipeline, tgt_pipeline: callables mapping a sentence to tokens.
        src_vocab, tgt_vocab: callables mapping a token list to an id list.
        device: device on which the tensors are created.
        max_padding: fixed sequence length of every output row.
        pad_id: fill value for padding positions.

    Returns:
        ``(src, tgt)``: int64 tensors of shape ``(number of pairs, max_padding)``.
    """
    # Assumes "<s>" and "</s>" have ids 0 and 1 (the first specials listed in
    # build_vocabulary); the default pad_id=2 likewise corresponds to "<blank>".
    start_id = torch.tensor([0], device=device)
    end_id = torch.tensor([1], device=device)

    def encode(text, pipeline, vocab):
        token_ids = torch.tensor(vocab(pipeline(text)), dtype=torch.int64, device=device)
        framed = torch.cat([start_id, token_ids, end_id], dim=0)
        return F.pad(framed, pad=(0, max_padding - len(framed)), value=pad_id)

    src_rows, tgt_rows = [], []
    for src_text, tgt_text in batch:
        src_rows.append(encode(src_text, src_pipeline, src_vocab))
        tgt_rows.append(encode(tgt_text, tgt_pipeline, tgt_vocab))
    return torch.stack(src_rows), torch.stack(tgt_rows)


def create_dataloader(
        device,
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=1200,
        max_padding=128,
        is_distributed=False,
):
    """Build the training and validation DataLoaders.

    Splits come from ``load_local_dataset``; the test split is ignored. Batches
    are collated by ``collate_batch``, tokenizing German source text with
    ``spacy_de`` and English target text with ``spacy_en``.

    Args:
        device: device on which collated tensors are created.
        vocab_src, vocab_tgt: source/target vocabularies.
        spacy_de, spacy_en: spaCy pipelines used for tokenization.
        batch_size: sentence pairs per batch in each loader.
        max_padding: fixed sequence length passed to ``collate_batch``.
        is_distributed: if True, each loader gets a ``DistributedSampler`` so
            every process only sees its own shard of the data. Otherwise the
            training loader shuffles and the validation loader does not.

    Returns:
        ``(train_dataloader, valid_dataloader)``
    """
    def collate_fn(batch):
        return collate_batch(
            batch,
            lambda text: tokenize(text, spacy_de),
            lambda text: tokenize(text, spacy_en),
            vocab_src,
            vocab_tgt,
            device,
            max_padding=max_padding,
        )

    train_split, valid_split, _ = load_local_dataset().values()

    def make_loader(split, shuffle):
        dataset = to_map_style_dataset(split)
        # A DistributedSampler shards the dataset so each process handles only its part.
        sampler = DistributedSampler(dataset) if is_distributed else None
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle and sampler is None,
            sampler=sampler,
            collate_fn=collate_fn,
        )

    return make_loader(train_split, True), make_loader(valid_split, False)


def train_worker(
        gpu,
        ngpu_per_node,
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        config,
        is_distributed=False,
        logger: Logger = None,
):
    """Train the Transformer on one GPU, optionally as one rank of a DDP job.

    Args:
        gpu: CUDA device index for this process. In distributed mode it is
            also used as the process rank.
        ngpu_per_node: number of participating processes (the world size).
            ``config["batch_size"]`` is divided by it to get the per-process
            batch size.
        vocab_src: source-language vocabulary.
        vocab_tgt: target-language vocabulary; its ``"<blank>"`` id is used
            as the padding index.
        spacy_de: German spaCy pipeline used to tokenize source sentences.
        spacy_en: English spaCy pipeline used to tokenize target sentences.
        config: training settings. Keys read here: ``batch_size``,
            ``max_padding``, ``base_lr``, ``warmup``, ``checkpoint_prefix``,
            ``num_epochs`` and ``accum_iter``.
        is_distributed: if True, join an NCCL process group (``env://``
            rendezvous) and wrap the model in ``DistributedDataParallel``.
        logger: destination for progress messages. Falls back to this
            module's logger when None.

    Side effects:
        The main process (rank 0, or the only process) writes
        ``<checkpoint_prefix>/NN.pt`` after each epoch's training pass and
        ``<checkpoint_prefix>/transformer_model.pt`` after the last epoch.
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    logger.info(f"Train worker process using GPU: {gpu} for training.")

    torch.cuda.set_device(gpu)
    pad_idx = vocab_tgt["<blank>"]
    # Model width fed to the LR schedule. Assumes make_model builds a
    # 512-wide model by default — TODO confirm if make_model's defaults change.
    d_model = 512
    model = make_model(len(vocab_src), len(vocab_tgt), n=6)
    model.cuda(gpu)
    module = model  # the unwrapped model; its state_dict is what gets checkpointed
    is_main_process = True
    if is_distributed:
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://", rank=gpu, world_size=ngpu_per_node
        )
        model = DDP(model, device_ids=[gpu], output_device=gpu)
        module = model.module
        is_main_process = gpu == 0

    criterion = LabelSmoothing(size=len(vocab_tgt), padding_idx=pad_idx, smoothing=0.1)
    criterion.cuda(gpu)

    train_dataloader, valid_dataloader = create_dataloader(
        gpu,
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=config["batch_size"] // ngpu_per_node,
        max_padding=config["max_padding"],
        is_distributed=is_distributed,
    )

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: rate(step, d_model, factor=1, warmup=config["warmup"]),
    )
    scaler = GradScaler()
    train_state = TrainState()

    save_prefix = config["checkpoint_prefix"]
    # exist_ok=True: in distributed mode every rank runs this line, and a
    # separate exists()/makedirs() pair can raise FileExistsError for the
    # rank that loses the race.
    os.makedirs(save_prefix, exist_ok=True)

    for epoch in range(config["num_epochs"]):
        if is_distributed:
            # Give the DistributedSamplers a new epoch so their shuffling changes.
            train_dataloader.sampler.set_epoch(epoch)
            valid_dataloader.sampler.set_epoch(epoch)

        logger.info(f"[GPU:{gpu}] Epoch {epoch} ====")
        logger.info("Before training start, the TrainState is:")
        logger.info(f"step: {train_state.step}, samples: {train_state.samples}")
        logger.info(f"tokens: {train_state.tokens}, accum_step: {train_state.accum_step}")

        model.train()
        logger.info(f"[GPU:{gpu}] Epoch {epoch} Training ====")
        _, train_state = run_epoch(
            tqdm(
                iterable=(Batch(b[0], b[1], pad_idx) for b in train_dataloader),
                desc=f"Epoch {epoch} Training",
                total=len(train_dataloader)),
            model,
            SimpleLossCompute(criterion),
            optimizer,
            lr_scheduler,
            scaler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state,
            logger=logger
        )

        GPUtil.showUtilization()
        if is_main_process:
            file_path = os.path.join(save_prefix, "%.2d.pt" % epoch)
            torch.save(module.state_dict(), file_path)
        torch.cuda.empty_cache()

        model.eval()
        logger.info(f"[GPU:{gpu}] Epoch {epoch} Validation ====")
        # Validation needs no gradients; no_grad avoids building the autograd
        # graph. Assumes run_epoch does not call backward() in mode='eval'.
        with torch.no_grad():
            sloss, _ = run_epoch(
                tqdm(
                    iterable=(Batch(b[0], b[1], pad_idx) for b in valid_dataloader),
                    desc=f"Epoch {epoch} Validation",
                    total=len(valid_dataloader)),
                model,
                SimpleLossCompute(criterion),
                DummyOptimizer(),
                DummySchedular(),
                scaler,
                mode='eval'
            )
        logger.info(f"Epoch({epoch}): scaled loss: {sloss}")
        torch.cuda.empty_cache()

    if is_main_process:
        file_path = os.path.join(save_prefix, "transformer_model.pt")
        torch.save(module.state_dict(), file_path)

    if is_distributed:
        # Tear down the process group created by init_process_group above.
        torch.distributed.destroy_process_group()


def train_distributed_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config, logger: Logger = None):
    """Spawn one ``train_worker`` process per visible GPU on this machine.

    The workers rendezvous via ``MASTER_ADDR``/``MASTER_PORT`` (``env://``
    init in ``train_worker``). Both are set to ``localhost:12356`` here,
    overwriting any existing values.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    ngpus = torch.cuda.device_count()
    logger.info(f"Number of GPUs detected: {ngpus}")
    if ngpus < 1:
        # With nprocs=0, mp.spawn would start no workers and training would
        # be skipped without any error being reported.
        raise RuntimeError("Distributed training requires at least one CUDA device.")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    logger.info("Spawning training processes ...")
    # NOTE(review): mp.spawn pickles `logger` for the children. A stdlib Logger
    # pickles by name, so each worker receives logging.getLogger(name) in a
    # fresh process. Handlers attached only in the parent may be missing
    # there — confirm worker logs reach their destination.
    mp.spawn(
        train_worker,
        nprocs=ngpus,
        args=(ngpus, vocab_src, vocab_tgt, spacy_de, spacy_en, config, True, logger),
    )


def train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config, logger: Logger = None):
    """Dispatch training to the multi-GPU launcher or to a single in-process worker.

    ``config["distributed"]`` selects the mode. In single-process mode the
    worker runs on GPU 0 with a world size of 1 and no process group.
    """
    shared_args = (vocab_src, vocab_tgt, spacy_de, spacy_en, config)
    if not config["distributed"]:
        train_worker(0, 1, *shared_args, False, logger)
        return
    train_distributed_model(*shared_args, logger)


if __name__ == "__main__":
    # Read the configuration first so a missing or malformed config.yaml fails
    # before the slow tokenizer loading and vocabulary building. The explicit
    # encoding keeps parsing independent of the platform's default locale.
    with open("config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    spacy_de, spacy_en = load_tokenizers()
    vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)

    logger = utils.logger.get_logger(save_prefix=config['log_prefix'])
    logger.info(str(config))
    train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config, logger)
